{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# RNN-循环神经网络\n",
    "\n",
    "用 TensorFlow 1.x（图模式）在人工生成的 0/1 序列上训练一个基础 RNN：每个标签依赖于 3 步和 8 步之前的输入，观察训练损失随训练的变化。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "D:\\ProgramData\\Anaconda3\\lib\\site-packages\\h5py\\__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
      "  from ._conv import register_converters as _register_converters\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from tensorflow.python import debug as tf_debug\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 超参数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_steps = 10          # time steps per training window (RNN unroll length)\n",
    "batch_size = 200        # parallel sequences per training step\n",
    "num_classes = 2         # binary inputs/labels: one-hot depth and softmax width\n",
    "state_size = 16         # hidden-state size of the RNN cell\n",
    "learning_rate = 0.1     # Adagrad learning rate\n",
    "random_seed = 0         # fixed seed so data generation and weight init are reproducible\n",
    "\n",
    "np.random.seed(random_seed)      # gen_data samples from numpy's global RNG\n",
    "tf.set_random_seed(random_seed)  # graph-level seed; applies to ops created after this call"
   ]
  },
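  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 生成数据\n",
    "\n",
    "输入 $X_t$ 是独立、等概率的 0/1 序列；标签 $Y_t$ 取 1 的概率为 $0.5 + 0.5\\,[X_{t-3}=1] - 0.25\\,[X_{t-8}=1]$（即 0.5、1、0.25 或 0.75），因此模型需要记住 3 步和 8 步之前的输入才能更准确地预测 $Y_t$。"
   ]
  },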
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def gen_data(size=1000000):\n",
    "    \"\"\"Generate the synthetic binary sequence used to train the RNN.\n",
    "\n",
    "    X[t] is drawn uniformly from {0, 1}. Y[t] is 1 with probability\n",
    "        0.5 + 0.5 * (X[t-3] == 1) - 0.25 * (X[t-8] == 1),\n",
    "    i.e. 0.5, 1.0, 0.25 or 0.75 depending on the two lagged inputs.\n",
    "    Negative lags wrap around (exactly like X[i-3] in a plain Python loop), so the\n",
    "    first 3 / 8 labels depend on inputs at the end of the sequence.\n",
    "\n",
    "    Args:\n",
    "        size: length of the generated sequence.\n",
    "\n",
    "    Returns:\n",
    "        (X, Y): two 1-D integer arrays of length `size`.\n",
    "    \"\"\"\n",
    "    X = np.random.choice(2, size=(size,))\n",
    "    # np.roll(X, k)[t] == X[t - k], wrapping around at the start of the array.\n",
    "    threshold = 0.5 + 0.5 * (np.roll(X, 3) == 1) - 0.25 * (np.roll(X, 8) == 1)\n",
    "    # One uniform draw per position, consumed in order: Y[t] = 0 if the draw exceeds threshold[t], else 1.\n",
    "    Y = np.where(np.random.rand(size) > threshold, 0, 1)\n",
    "    return X, Y\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 生成batch数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def gen_batch(raw_data, batch_size, num_step):\n",
    "    \"\"\"Cut one long (X, Y) sequence into consecutive training windows.\n",
    "\n",
    "    The sequence is split into `batch_size` contiguous rows of length\n",
    "    len(raw_x) // batch_size (any remainder is dropped). Window i holds columns\n",
    "    [i*num_step, (i+1)*num_step) of every row, so each row of window i+1\n",
    "    continues the same row of window i; this is what allows the training loop\n",
    "    to carry the RNN state from one window to the next.\n",
    "\n",
    "    Args:\n",
    "        raw_data: pair (raw_x, raw_y) of equal-length 1-D sequences.\n",
    "        batch_size: number of parallel rows per window.\n",
    "        num_step: number of time steps per window.\n",
    "\n",
    "    Yields:\n",
    "        (x, y): int32 arrays of shape (batch_size, num_step). With 1,000,000\n",
    "        samples, batch_size=200 and num_step=10 each row has 5000 columns and\n",
    "        an epoch yields 500 windows.\n",
    "    \"\"\"\n",
    "    raw_x, raw_y = raw_data\n",
    "    batch_partition_length = len(raw_x) // batch_size\n",
    "    usable = batch_size * batch_partition_length\n",
    "    # Row r of data_x is raw_x[r*L:(r+1)*L] with L = batch_partition_length.\n",
    "    data_x = np.asarray(raw_x[:usable], dtype=np.int32).reshape(batch_size, batch_partition_length)\n",
    "    data_y = np.asarray(raw_y[:usable], dtype=np.int32).reshape(batch_size, batch_partition_length)\n",
    "    # Window length comes from the num_step argument, not the global num_steps.\n",
    "    epoch_size = batch_partition_length // num_step\n",
    "    for i in range(epoch_size):\n",
    "        window = slice(i * num_step, (i + 1) * num_step)\n",
    "        yield data_x[:, window], data_y[:, window]\n",
    "\n",
    "\n",
    "def gen_epochs(n, num_steps):\n",
    "    \"\"\"Yield `n` epochs, each a gen_batch generator over freshly generated data.\"\"\"\n",
    "    for _ in range(n):\n",
    "        yield gen_batch(gen_data(), batch_size, num_steps)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 定义placeholder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Integer inputs and labels: one row per parallel stream, num_steps columns.\n",
    "x = tf.placeholder(dtype=tf.int32, shape=[batch_size, num_steps], name=\"x\")\n",
    "y = tf.placeholder(dtype=tf.int32, shape=[batch_size, num_steps], name='y')\n",
    "\n",
    "# All-zero initial hidden state. It is a plain tensor rather than a placeholder, but\n",
    "# Session.run still accepts it as a feed_dict key, which train_rnn uses to pass in\n",
    "# the final state of the previous window.\n",
    "init_state = tf.zeros(shape=[batch_size, state_size])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## RNN输入"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# (batch_size, num_steps) integer ids -> (batch_size, num_steps, num_classes) one-hot floats.\n",
    "x_one_hot = tf.one_hot(indices=x, depth=num_classes)\n",
    "# static_rnn consumes a list of num_steps tensors, each of shape (batch_size, num_classes).\n",
    "rnn_inputs = tf.unstack(x_one_hot, axis=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 定义RNN cell"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From D:\\ProgramData\\Anaconda3\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\base.py:198: retry (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use the retry module or similar alternatives.\n"
     ]
    }
   ],
   "source": [
    "# One vanilla RNN cell with state_size hidden units, unrolled statically over the\n",
    "# num_steps input tensors. rnn_outputs is the per-step output list; final_state is\n",
    "# the hidden state after the last step.\n",
    "cell = tf.contrib.rnn.BasicRNNCell(num_units=state_size)\n",
    "rnn_outputs, final_state = tf.contrib.rnn.static_rnn(\n",
    "    cell,\n",
    "    rnn_inputs,\n",
    "    initial_state=init_state,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 预测，损失，优化"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Output projection shared by every time step: hidden state -> class logits.\n",
    "with tf.variable_scope('softmax'):\n",
    "    W = tf.get_variable('W', shape=[state_size, num_classes])\n",
    "    b = tf.get_variable('b', shape=[num_classes],\n",
    "                        initializer=tf.constant_initializer(0.0))\n",
    "\n",
    "logits = [tf.matmul(h, W) + b for h in rnn_outputs]\n",
    "predictions = [tf.nn.softmax(step_logits) for step_logits in logits]\n",
    "\n",
    "# One label tensor per time step, aligned with `logits`.\n",
    "y_as_list = tf.unstack(y, num=num_steps, axis=1)\n",
    "losses = [\n",
    "    tf.nn.sparse_softmax_cross_entropy_with_logits(labels=step_labels, logits=step_logits)\n",
    "    for step_logits, step_labels in zip(logits, y_as_list)\n",
    "]\n",
    "# Mean cross-entropy over every time step and every row of the batch.\n",
    "total_loss = tf.reduce_mean(losses)\n",
    "train_step = tf.train.AdagradOptimizer(learning_rate).minimize(total_loss)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 训练网络"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "epoch 0\n",
      "第 100 步的平均损失 0.5067515727877617\n",
      "第 200 步的平均损失 0.48099254071712494\n",
      "第 300 步的平均损失 0.476533437371254\n",
      "第 400 步的平均损失 0.4758702477812767\n",
      "\n",
      "epoch 1\n",
      "第 100 步的平均损失 0.47936957120895385\n",
      "第 200 步的平均损失 0.471754455268383\n",
      "第 300 步的平均损失 0.46945293366909024\n",
      "第 400 步的平均损失 0.46793452560901644\n",
      "\n",
      "epoch 2\n",
      "第 100 步的平均损失 0.4710334450006485\n",
      "第 200 步的平均损失 0.46511719256639483\n",
      "第 300 步的平均损失 0.4647247526049614\n",
      "第 400 步的平均损失 0.46362601190805436\n",
      "\n",
      "epoch 3\n",
      "第 100 步的平均损失 0.4690332645177841\n",
      "第 200 步的平均损失 0.46084885895252226\n",
      "第 300 步的平均损失 0.46341545045375826\n",
      "第 400 步的平均损失 0.46005783438682557\n",
      "\n",
      "epoch 4\n",
      "第 100 步的平均损失 0.4685878440737724\n",
      "第 200 步的平均损失 0.4614964243769646\n",
      "第 300 步的平均损失 0.4612439852952957\n",
      "第 400 步的平均损失 0.459266314804554\n",
      "0.5067515727877617\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAD8CAYAAACb4nSYAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAIABJREFUeJzt3Xl8VOXZ//HPlcm+A0lYQiAJSxAQBAKC7BVZqg+0gopaFDdca7VWxae7tr+61KW2PCoqaCt1wa0UUQQFBBVIQPY1JAFCWEICSSAJ2e7fHzPBGBIyJJOZyZzr/XrlxSznzLlymHznzH3uc99ijEEppZQ1+Hm6AKWUUu6joa+UUhaioa+UUhaioa+UUhaioa+UUhaioa+UUhaioa+UUhaioa+UUhaioa+UUhbi7+kC6oqJiTGJiYmeLkMppVqVDRs2HDfGxDa2nNeFfmJiIunp6Z4uQymlWhUR2e/Mctq8o5RSFqKhr5RSFqKhr5RSFqKhr5RSFqKhr5RSFqKhr5RSFqKhr5RSFuIzoX/oZCnPfr6bA/klni5FKaW8ls+EflFpBX//MoPNOSc9XYpSSnktnwn9pJgwALKOn/ZwJUop5b18JvSDA2zER4eQmXfK06UopZTX8pnQB/vRvh7pK6VUw3wu9DOPn8YY4+lSlFLKK/lc6BeXVZJ/utzTpSillFfyqdBPjrWfzM3M0yYepZSqj2+Ffkw4AFnH9WSuUkrVx6dCP75NCAE2IVNP5iqlVL18KvRtfkLXdmFkafOOUkrVy6dCHyBZu20qpVSDfC70k2LD2J9fQlW1dttUSqm6fC70k2PCKK+q5tCJUk+XopRSXsfnQj/J0YMnU3vwKKXUOXwu9Gv66mu7vlJKncvnQr9dWCARwf56gZZSStXD50JfRLQHj1JKNcDnQh90tE2llGqIT4Z+cmw4h06WUlZR5elSlFLKq/hk6OssWkopVT8NfaWUshANfaWUshCfDP2wIH86RAZrt02llKrDJ0MfaqZO1KtylVKqNt8N/VjttqmUUnU5FfoiMlFEdotIhojMruf5mSKSJyKbHD+313ruZhHZ6/i52ZXFn09yTBgnSyo4ofPlKqXUWf6NLSAiNmAOcAWQA6SJyCJjzI46i75rjLmvzrptgd8DqYABNjjWPeGS6s/j7Hy5x08zKCywpTenlFKtgjNH+kOADGNMpjGmHHgHmOLk608AlhljChxBvwyY2LRSL8zZ0TbztF1fKaVqOBP68cDBWvdzHI/VNVVEtojI+yKScCHrisgsEUkXkfS8vDwnSz+/zm1C8PcTbddXSqlanAl9qeexutNS/RdINMb0A5YDb17Auhhj5hpjUo0xqbGxsU6U1LgAmx9d2oZq6CulVC3OhH4OkFDrfmcgt/YCxph8Y8wZx91XgUHOrtuSkrUHj1JK/YAzoZ8G9BCRJBEJBKYDi2ovICIda92dDOx03F4KjBeRNiLSBhjveMwtakbbrNb5cpVSCnCi944xplJE7sMe1jZgnjFmu4g8DqQbYxYB94vIZKASKABmOtYtEJEnsH9wADxujClogd+jXkkx4ZyprCa3sJTObULdtVmllPJajYY+gDFmCbCkzmO/q3X7MeCxBtadB8xrRo1NVnsMHg19pZTy4StyQefLVUqpunw69OMigggLtOnAa0op5eDToS8iJMWGkalH+kopBfh46IP9ZG6WjraplFKAJUI/jJwTpZyp1PlylVLK50O/W2wYxsD+/BJPl6KUUh7n86Ff021TT+YqpZQFQj9R58tVSqmzfD70I4MDiAkP0pO5SimFBUIf7BdpafOOUkpZJfRjdLRNpZQCi4R+UkwY+afLKSyp8HQpSinlUZYJfYCsfD3aV0pZmyVCPznWPl+unsxVSlmdJUK/S9tQ/ET76iullCVCP9Dfj4S2oTrwmlLK8iwR+uCYOlGP9JVSFmeZ0E+OCSfr+GmM
0flylVLWZZnQT4oNo7SiiiNFZZ4uRSmlPMYyoZ9c021Tm3iUUhZmmdA/O9qmnsxVSlmYZUK/Q2QwIQE2HY5BKWVplgl9Pz8hMSaMzDy9QEspZV2WCX3QgdeUUspSoZ8UE8bBE6WUV1Z7uhSllPIIS4V+cmwYVdWGgyd0vlyllDVZKvR1vlyllNVZMvR1tE2llFVZKvSjQwNpGxaoJ3OVUpZlqdAH+9G+Nu8opazKcqGfHBOmV+UqpSzLcqGfFBtGXvEZist0vlyllPU4FfoiMlFEdotIhojMPs9y00TEiEiq436giMwXka0isllExrio7iarGXgt+7h221RKWU+joS8iNmAOMAnoDVwvIr3rWS4CuB9YV+vhOwCMMRcDVwDPiohHv10kxdjny83UHjxKKQtyJoCHABnGmExjTDnwDjClnuWeAJ4Gag9Y3xv4AsAYcww4CaQ2q+Jm6touFNH5cpVSFuVM6McDB2vdz3E8dpaIDAASjDGL66y7GZgiIv4ikgQMAhKaUW+zBQfYiI8O0W6bSilL8ndiGannsbNzDjqaa54HZtaz3DzgIiAd2A98A1SeswGRWcAsgC5dujhRUvMk6cBrSimLcuZIP4cfHp13BnJr3Y8A+gIrRSQbGAosEpFUY0ylMeZBY8wlxpgpQDSwt+4GjDFzjTGpxpjU2NjYpv4uTqsZbVPny1VKWY0zoZ8G9BCRJBEJBKYDi2qeNMYUGmNijDGJxphEYC0w2RiTLiKhIhIGICJXAJXGmB2u/zUuTHJsOKfOVJJXfMbTpSillFs12rxjjKkUkfuApYANmGeM2S4ijwPpxphF51k9DlgqItXAIWCGK4purtpTJ8ZFBnu4GqWUch9n2vQxxiwBltR57HcNLDum1u1sIKXp5bWM7wdeO83Q5HYerkYppdzHclfkAnSKDiHQ309P5iqlLMeSoW/zE5La6Xy5SinrsWTog2O0TT3SV0pZjHVDPzaMA/klVFbpfLlKKeuwbujHhFFZbcg5UerpUpRSym0sG/rdYmu6bWq7vlLKOiwb+mdH29SB15RSFmLZ0G8TGkBUSIB221RKWYplQ19EdOA1pZTlWDb0AZJjNfSVUtZi7dCPCeNwYRkl5eeM9qyUUj7J0qFfczJXj/aVUlZh8dD/fuA1pZSyAg19IEu7bSqlLMLSoR8SaKNTVLCOwaOUsgxLhz7Yx+DR0FdKWYWGfkwYWXmndL5cpZQlaOjHhFNUVknB6XJPl6KUUi3O8qGfHPv9fLlKKeXrNPS1B49SykIsH/rx0SEE2ESP9JVSlmD50Pe3+dGlbShZOq6+UsoCLB/6AMmx4TquvlLKEjT0sbfr788voapau20qpXybhj72vvrlVdXkntT5cpVSvk1Dn+/H4NGTuUopX6ehj71NHyAzT0/mKqV8m4Y+EBMeSESQvw6xrJTyeRr6OObL1akTlVIWoKHvkBQTpt02lVI+T0PfITkmnNzCUsoqqjxdilJKtRgNfYek2DCMgex8PdpXSvkuDX2HXh0iAHhr7X4dW18p5bOcCn0RmSgiu0UkQ0Rmn2e5aSJiRCTVcT9ARN4Uka0islNEHnNV4a7Ws30Ed4xM4q21B5j3dbany1FKqRbh39gCImID5gBXADlAmogsMsbsqLNcBHA/sK7Ww9cAQcaYi0UkFNghIm8bY7Jd9Qu40mOTLuJgQSl/+mQHnduEMKFPB0+XpJRSLuXMkf4QIMMYk2mMKQfeAabUs9wTwNNAWa3HDBAmIv5ACFAOFDWv5Jbj5yc8f90l9OsczS/e+Y7NB096uiSllHIpZ0I/HjhY636O47GzRGQAkGCMWVxn3feB08Bh4ADwV2NMQdPLbXkhgTZeuymVmPAgbnsznYMFJZ4uSSmlXMaZ0Jd6Hjt7plNE/IDngYfqWW4IUAV0ApKAh0Qk+ZwNiMwSkXQRSc/Ly3Oq8JYUGxHEG7cMpryyilvfSKOwtMLTJSmllEs4E/o5QEKt+52B3Fr3I4C+
wEoRyQaGAoscJ3NvAD4zxlQYY44BXwOpdTdgjJlrjEk1xqTGxsY27Tdxse5xEbw8YxDZ+ae5Z8EGyiurPV3SeZVVVLFg3X4KS/QDSinVMGdCPw3oISJJIhIITAcW1TxpjCk0xsQYYxKNMYnAWmCyMSYde5POj8QuDPsHwi6X/xYt5LJuMfzl6n58nZHPrz/a6rVdOYvLKpg5fz2//mgbT37WanavUsoDGg19Y0wlcB+wFNgJvGeM2S4ij4vI5EZWnwOEA9uwf3jMN8ZsaWbNbjVtUGfu/1F3Fm7IYc6KDE+Xc45jxWVc98pa0rNPMKBLNAvTD+p5CKVUgxrtsglgjFkCLKnz2O8aWHZMrdunsHfbbNUevKInBwpK+Ovne0hoG8qUS+IbX8kN9uefZsbr68krPsNrN6eS0iGC0U+vZM6KDJ6c2s/T5SmlvJBekesEEeGpaf0YktSWhxduIS3b8x2Qth0qZOpL31JcVsG/77iUMSlxdIwK4fohCSzckMOBfD3aV0qdS0PfSUH+NubOGETnNiHc8c90jw7D/M2+40yfu5ZAm7DwrssY0KXN2efuGdsdm5/w9y/3eqw+pZT30tC/ANGhgcy/ZTB+Itwyfz0Fp8vdXsOSrYeZOS+NjlHBfHDPZXSPC//B8+0jg7nx0i58+N0hsnV+AKVUHRr6F6hruzBevWkQuYVlzPpnuluHYn5r7X7u/fdGLu4cxcK7htExKqTe5e4e040Am/CiHu0rperQ0G+CQV3b8vy1l5C+/wQPv7+F6uqW7cppjOGF5Xv4zcfbGJsSx1u3XUp0aGCDy8dFBPOzS7vy8XeHdN5fpdQPaOg30ZX9OvLoxF78d3Muzy7b3WLbqao2/PY/23hh+V6mDuzMKzMGERJoa3S9O0d3I9Dfjxe/0KN9pdT3NPSb4a7RyUwfnMCcFft4L+1g4ytcoDOVVfz87Y28tfYAd45O5q/X9CPA5tx/WWxEEDcPS2TR5lwyjunRvlLKTkO/GUSEJ37Sl5E9Yvjfj7ayZu9xl712cVkFM+elsWTrEX5z5UU8NukiROobBqlhs0YlExxg06N9pdRZGvrNFGDzY86NA+kWG87db23gD4u28/b6A2zYf4LisqaNg5NXfIbpc9eSll3A89f15/aR54xR55R24UHcNCyR/27JZc/R4ia9hlLKt4i3jSeTmppq0tPTPV3GBTt0spSH3tvElpxCSsq/79ETHx1Cz/bh9OwQQa8OEfRsH0G32HCCA+pvlz+QX8KMees4VnSG//vZQMamxDWrroLT5Yx86kvG9Ipjzg0Dm/VaSinvJSIbjDHnDGhZl1PDMKjGxUeH8M6sYVRXGw6dLGX3kWJ2Hy1mz9Fidh8pZk3GcSqq7B+wfgKJMWGktLd/CKQ4PgxKy6u45Y00KqurWXDHpQysddFVU7UNC2Tm8ET+b+U+dh0poleHyGa/plKq9dIjfTepqKpmf/5pdh85Zf8wOGL/QMjOP03tHp+dooL5521D6B4X4bJtnywpZ8RTKxjZI4aXfjbIZa+rlPIeeqTvZQJsfnSPi6B7XARX0vHs42UVVWQcO8XuI8UcKSrj6oHxDV501VTRoYHcOjyRF7/MYEduEb076dG+UlalJ3I9LDjARt/4KKYO6sy9Y7u7PPBr3DYimYhgf/72xZ4WeX2lVOugoW8RUaEB3DYiiaXbj7LtUKGny1FKeYiGvoXcOiKJyGB/Xliu/faVsioNfQuJDA7g9pHJLN95lC05Jz1djlLKAzT0LeaW4YlEhQTo0b5SFqWhbzERwQHMGpXMl7uOsemgHu0rZTUa+hZ082WJtAkN4Pll2pNHKavR0Leg8CB/Zo3qxqo9eWzYf8LT5Sil3EhD36JuGtaVtmGBvLBcj/aVshINfYsKC/LnzlHJrN57nPTsAk+Xo5RyEw19C5sxrCsx4YE8r0f7SlmGhr6FhQb6c9fobnydkc+6zHxPl6OUcgMNfYu78dKuxEYE6dG+UhahoW9x
IYE27h7djbWZBXyzz3XTPSqlvJOGvuKGS7sQFxHEC8v24m3zKyilXEtDXxEcYOPesd1Zn13AK19lUllV7emSlFItRENfATB9SAJjU2J58tNdXPX3NazP0m6cSvkiDX0FQJC/jXkzB/PyzwZSXFbJta98y4PvbuJYUZmnS1NKuZCGvjpLRJjYtyPLfzma+8Z255Mth/nRs6t4bXUmFV7e5JN9/DS3vZHGmr16Mlqp89HQV+cICbTxqwkpLH1wFKmJbfjTJzu56sU1rPXSvvyr9uQx+R9r+GLXMX753iYKSys8XZJSXsup0BeRiSKyW0QyRGT2eZabJiJGRFId928UkU21fqpF5BJXFa9aVlJMGPNnDmbujEGcOlPJ9Llr+cU733HUS5p8jDG8vGoft8xfT6foEF66cSD5p8v5f5/s9HRpSnkt/8YWEBEbMAe4AsgB0kRkkTFmR53lIoD7gXU1jxljFgALHM9fDPzHGLPJdeWrliYijO/TgZE9YnlpZQYvf5XJ8h1HeWBcT2YOTyTA5pkviyXllTzy/hYWbznMlf068sy0foQG+nNHTiEvr9rH5Es6Mbx7jEdqU8qbOfMXOwTIMMZkGmPKgXeAKfUs9wTwNNDQYeD1wNtNqlJ5XEigjV+OT2HZg6O4NLkdf16ykx//bbVHLug6WFDC1Je+5ZOth3l0Yi/+cf0AQgPtxy8PjOtBckwYsz/cQkl5pdtrU8rbORP68cDBWvdzHI+dJSIDgARjzOLzvM51NBD6IjJLRNJFJD0vL8+JkpSndG0XxryZg3ntplTKKqu44dV13PfvjRwuLHXL9r/JOM7kf6wh50QJ82cO5u4x3RCRs88HB9h4cmo/DhaU8telOrSEUnU5E/pSz2NnL9sUET/geeChBl9A5FKgxBizrb7njTFzjTGpxpjU2NhYJ0pSnjaud3uWPTiaB8b1YNmOo1z+7CpeXrWPsoqqFtmeMYbX12QxY9562oUHsei+EYxJiat32SFJbblpWFfmf5Olk8QoVYczoZ8DJNS63xnIrXU/AugLrBSRbGAosKjmZK7DdLRpx+cEB9h4YFxPlj04msu6tePJT3cx5M/L+f1/trHzcJHLtlNWUcVDCzfzxOIdXN4rjo/vHU5STNh513lkYi86RYXw6AdbOFPZMh9ESrVG0thYKyLiD+wBLgcOAWnADcaY7Q0svxL4lTEm3XHfDzgAjDLGZDZWUGpqqklPT7+Q30F5ibWZ+fx73QE+23aE8qpq+idEc/3gBP6nfyfCghrtM1Cv3JOl3PXWBrbkFPLguJ78/Efd8fOr78vnuVbuPsbM+Wn8/EfdeWh8SpO235p8vv0I/ROiaR8Z7OlSlAeIyAZjTGpjyzX6l2iMqRSR+4ClgA2YZ4zZLiKPA+nGmEWNvMQoIMeZwFet29DkdgxNbseJ0+V8+N0h3ll/gNkfbuWJxTuYfEknpg/uQr/OUT9ogz+ftOwC7n5rA2UV1bx6UypX9G5/QfWMSYnj6oHxvLRyH5P6dqR3p8im/FqtwpKth7lnwUZG94zlzVuHeLoc5cUaPdJ3Nz3S9x3GGDYeOMHb6w+yeEsuZRXVXNQxkhuGJDBlQDyRwQENrrdg3QH+sGg7CW1DefWmQXSPi2hSDSdLyhn33Co6RoXw0T2X4e+hLqYt6XBhKRNfWM2ZyirKKqpZ/PMR9I2P8nRZbvfG11mkdIhkWLd2ni7FI5w90ve9vwDlNUSEQV3b8tdr+rP+1+N44id9EeC3/9nOkD8v56H3NpOeXfCD4ZzPVFbxvx9t5Tcfb2NEjxg+vnd4kwMfIDo0kMen9GXroUJeXZ3lgt/Ku1RXG3757mbKK6t5d9YwwoP8eXnVPk+X5XYZx4r5w3938NB7m1qsM4GvaFpDq1IXKDI4gBlDu/KzS7uw9VAhb68/yKJNh/hgYw7d48KZPjiBkT1ieezDLWw8cJJ7x3bjl1ekYHOy/f58fnxxRyb26cDzy/cwoU97kmPDXfAb
eYdXV2fybWY+T029mP4J0dw4tAuvfpVJ9vHTJDZystuXvLY6C38/IbewjDe/yebO0d08XZLX0iN95VYiQr/O0fzl6otZ/+txPDX1YsKD/PnTJzuZ8MJX7DxczJwbBvLwhF4uCfwaj0/pQ7C/H49+sIXqau9q0myqbYcK+evnu5nUtwPXpto72N02PAl/mx+vfGWdU2h5xWf48LtDXDs4gdE9Y5mzIoOTJeWeLstraegrjwkL8ue6wV34+N7hfPbASB6ekMLH9w7nyn4dXb6tuMhgfntVb9KyT7Bg3X6Xv767lZRXcv8739EuLIi/XH3x2ZPjcZHBTBvUmQ825FhmWOx/rd1PeWU1t41IYvakXhSfqeSlldZr4nKWhr7yCr06RHLv2O6kdGh6+31jpg3qzMgeMTz56S5yTpS02Hbc4U+f7CTr+Gmeu7Y/0aGBP3juzlHJVFZX8/rXvncOo67S8ir+9W024y5qT7fYcC7qGMlPB8Qz/5tsDp10z1XirY2GvrIMEeH//fRiDPDrj7a12vmAP99+hH+vO8CskclcVs+gcl3bhXFlv04sWHvA54eZ/mBjDidKKrhjZNLZx2quyXjucx2Goz4a+spSEtqG8siEFFbtyePDjYc8Xc4FO1ZUxqMfbKFPp8jzXnB21+hkTp2p5K21rb8pqyHV1fahOfp3jmJIUtuzj8dHhzDzskQ+/C7HpVeG+woNfWU5Nw1LJLVrGx5fvIO84jOeLsdp1dWGhxZuprSiir9NH0Cgf8N/vn06RTEmJZZ5a7IoLffNLozLdx4l6/hpbh+ZfM4Ff/eM6UZEkD9PfbbLQ9V5Lw19ZTl+fsKTU/tRWlHF7xfVOwagV5r/TTar9x7nt1f1pntc491O7x7djfzT5SzccLDRZVujV1dnEh8dwqS+Hc55Ljo0kHvHdmfl7jyPDP/tzTT0lSV1jwvnF5f3YMnWI3y27bCny2nUjtwinvp0F+Muas8NQ7o4tc6QpLYM7BLNK6u8f47jC/XdgROkZZ/g1hFJDV5lffNliXSKCubJT3f5TDddV9DQV5Y1a1QyvTtG8tv/bKewxHtPeJZVVPGLd74jKjSAp6Ze7PTYRSLCPWO6c+hkKYu35Da+Qivy2uosIoL9uW5wQoPLBAfYJ/7ZklPIJ1u9/4PdXTT0lWUF2Px4elo/Ck6X86dPdjS+gof8ZclO9h47xbPX9KddeNAFrfujXnH0bB/OSyv3+czR7sGCEj7ddpgbLu1CeCOjt/50QDy9OkTwzNLdlFf61redptLQV5bWNz6KO0cls3BDDl/t8b5Z21bsOsab3+7nthFJjOp54RMM+fkJd43uxp6jp1ix+1gLVOh+r6/Jwk+EWy5LanRZm5/w6KReHCgo4d8+cFGeK2joK8u7//IeJMeG8diHWzmQX+I17d95xWd4+P3N9OoQwcMTmj4fwP/070R8dAj/t3Jfq702oUZhSQXvpR9kcv9OdIhybt6AMT1jGZbcjhe/zKC4zHub8dxFB1xTlhccYOPpqf245pVvGfXMCgBiwgOJiwimfWQQ7SODiYu03+4QGey4H0S7sCCXjg9UmzGGR97fTHFZJf++YyjBAbYmv1aAzY9Zo5L5/aLtpGWf+EGf9tZmwfr9lJRXcfvIZKfXERFmT+rFlDlfM/erTEtMqHM+GvpKAamJbfnvfSPYdqiQo0VnOFpcxrGiMo4UlbEtt4jjp85Q9yDZ5ifEhgfRPjLo7IfCJQltGN0zltiIC2t7r+uf3+5nxe48/ji5Dz3bN39oimtTE3jxi728tDKDIUmtc5KV8spq3vg6mxHdYy54Qpz+CdFc2a8jr63OYsbQrsRZeHYxDX2lHPrGRzU4+UhlVTXHT5VztKjM/lN8hqOF398+WFDC2sx83lp7wPFakYxNiWNMSiyXJLS5oG8Ee44W8+clOxmbEstNw7q65HcLCbRxy/BE/vr5HnbkFrXKWcQWbc7lWPEZnrmmf5PWf3h8Cku3HeH55Xv5
y9UXu7i61kNDXykn+Nv86BAVfN525Opqw47DRazak8fK3ceYsyKDv3+ZQVRIACN7xDAmJa7RbwFlFVXc//Z3RAb78/S0/k53z3TGjKGJvLRyHy+v2seL1w9w2eu6gzGG11ZnktI+glE9zh1vyBmJMWHceGkX3lp3gNtGJDl1gZsv0tBXykX8/OTst4V7x3ansKSC1Rl5rNydx6o9eSzeYu8r3jc+kjE9a74FRP/g4qJnlu5m15Fi5s8c3OwmorqiQgO4cWhXXludya/Gp9ClXahLX78lrd57nF1HinlmWr9mfRD+/PIevL8hh6c/28XcmxqdWdAnaegr1UKiQgO4ql8nrurX6ZxvAS+t2sc/VvzwW0Cgvx+vr8ni5mFdGdsrrkVqum1EEm98nc0rX+3jzz9tPU0cr67OJDYiiMmXdGrW68SEB3Hn6G48t2wP6dkFpCa23pPaTaWhr5Qb1PctYE3GcVbuPsbKWt8CesSF89iPL2qxOtpHBjN1UDwLN+Twi3E9iItw3QnNsooqgvz9XNokBbDzcBGr9x7n4QkpBPk3vRdTjdtHJvGvtft58tNdLLxrmMvr9XbaT18pD4gKDeDKfh155pr+rP/fy/nk/hH85sqLePWm1GZ1z3TGnaO6UVlVzbw12S55PWMM76YdYPCfljPj9fUUubgv/GurswgJsHHjpc6NOdSY0EB/HhjXg/T9J1i246hLXrM10dBXysNEhD6dorh9ZLJbJjNPjAlj0sUdWbB2f7MDOvdkKTfPT+PRD7aSGBPG2sx8rnnpW3JdNGvV0aIyFm0+xLWpnc+ZIaw5rktNIDk2jKc+20Wll1yM5y4a+kpZ0N2ju1HcjElWao7uJzz/FWlZBTw+pQ//uXc4b9wyhNyTpfz0/75me25hs+t845tsqqoNt45ofMiFC+Fv8+ORCb3Yl3eahRtyXPra3k5DXykL6hsfxaie9klWyioubJKV3JOlzHQc3ffuFMnSB0Zx07BE/PyEET1iWHj3MPxEuPblb1nZjPF+Tp+pZMHa/Uzo04Gu7Vz/DWhCn/YM6tqG55ftoaS80uWv76009JWyqLtHd+P4qXKnj3RrH92vdxzdv33H0HO6fvbqEMlH9wynS7swbnsznXfWH2hSfe/qFqWDAAALDUlEQVSlH6SorJI7Rjk/5MKFEBEem9SLY8VnmLfG9yeRr6Ghr5RFDU1uyyUJ0cz9al+j7dqHCxs+uq9Ph6hgFt41jOHdY5j94VaeWbrrggZ7q6yq5vU1WQzq2oaBXdpc0O91IVIT23JF7/a8vCqT/FOtZ+rM5tDQV8qi7JOsdONgQWmDk4wYY3gv7SDjn7Mf3f9xcv1H9/UJD/Ln9ZtTmT44gTkr9vHgu5s4U+lcU9LS7UfJOVHKHRcwsFpTPToxhZLySv7+ZUaLb8sbaOgrZWHjLmpP9zj7JCt1j8Rrju4f+WALvTtF8tkDI7n5soaP7usTYPPjL1dfzMMTUvh4Uy43z1vf6Cxlxhjmrs6ka7tQrujdvkm/14XoHhfBdYMTWLBuPzsPF3nN0NotRS/OUsrCaiZZ+dXCzazcncfYXnEYY1iYnsMTi3dQWW344+Q+zBja9YLCvjYR4d6x3YmPDuHh9zcz9eVvmD9zMAlt6/+2kL7/BJsPnuSJKX1abOjquh4Y15OPvjvEpL+tBiAs0EZkSABRIQHf/xts/9f+409UaN3HAogJD2ryfnIX8bZJFVJTU016erqny1DKMiqqqhn99Ari24Tw4vUDeOzDrazcncelSW15elo/l/ac+XZfPnf+K51AfxvzZqbSr3P0OcvM+mc667ML+Hb25YQEtuyFarVtzy0kLauAwtJKCksrKCytoKjM8W/p9/+eLm+4iap7XDgPjOvBj/t2dHv4i8gGY0yjAwpp6CulmP91Fn/87w5CA20YA7Mn9WrW0f357D1azMz5aRScLucfNwzg8ou+b8LJOn6aHz27kvvGdvfayU4qqqrPfgjYPxjsHxLH
i8/w9voD7D12il4dInhofArjLopz2zAPLg19EZkI/A2wAa8ZY55sYLlpwEJgsDEm3fFYP+AVIBKodjxX1tC2NPSVcr+S8kqueO4rEtqG8NRU1x7d1+dYcRm3vpHGjtwi/jilLzOG2ucN+M3HW3kvLYc1s8e6dFwgd6mqNvx3cy4vLN9Ddn4J/TtH8cvxKYzqEdPi4e+y0BcRG7AHuALIAdKA640xO+osFwF8AgQC9xlj0kXEH9gIzDDGbBaRdsBJY0yD34809JXyjKpq47Y2dLBffPXzt7/jy13HuHNUMneMSmb4k1/yk0vieWpaP7fV0RIqq6r5cOMh/vbFXg6dLCW1axseGp/CsG7tWmybzoa+M713hgAZxphMY0w58A4wpZ7lngCeBmofxY8HthhjNgMYY/LPF/hKKc9xZ+ADhAX5M3fGIH42tAuvfJXJlH98zZnKam4f6dohFzzB3+bHtYMTWPGrMTzxk74cPFHC9a+u5cbX1rJh/wmP1uZM6McDB2vdz3E8dpaIDAASjDGL66zbEzAislRENorII82qVinlU/xtfjwxpS+zJ/Xi0MlSxqbE0sMFcwJ7i0B/P2YM7cqqh8fy26t6s/tIMVNf+oZb5q9na07zxyZqCme6bNb38X+2TUhE/IDngZkNvP4IYDBQAnzh+AryxQ82IDILmAXQpYtrhk9VSrUOIvZuo5d1a0dCm9Yzm9eFCA6wcduIJK4fksCb3+zn5VX7+J9/rGFCn/Y8eEVPenVw35zFzrTpDwP+YIyZ4Lj/GIAx5i+O+1HAPuCUY5UOQAEwGegOTDTGzHQs+1ugzBjzTEPb0zZ9pZSvKyqrYN6aLF5fncWp8kqu6teJB8b1oFts0+ftdWWbfhrQQ0SSRCQQmA4sqnnSGFNojIkxxiQaYxKBtcBkR++dpUA/EQl1nNQdDew4dxNKKWUdkcEBPDCuJ6sfHcvdo7vxxc6jXPHcKv60uOXjsdHmHWNMpYjchz3AbcA8Y8x2EXkcSDfGLDrPuidE5DnsHxwGWGKM+cRFtSulVKsWHRrIIxN7ceuIJF5eua/Bq5RdSS/OUkopH+DK5h2llFI+QkNfKaUsRENfKaUsRENfKaUsRENfKaUsRENfKaUsRENfKaUsRENfKaUsxOsuzhKRPGB/M14iBjjuonJagtbXPFpf82h9zePN9XU1xsQ2tpDXhX5ziUi6M1eleYrW1zxaX/Nofc3j7fU5Q5t3lFLKQjT0lVLKQnwx9Od6uoBGaH3No/U1j9bXPN5eX6N8rk1fKaVUw3zxSF8ppVQDWmXoi8hEEdktIhkiMrue54NE5F3H8+tEJNGNtSWIyAoR2Ski20XkF/UsM0ZECkVkk+Pnd+6qr1YN2SKy1bH9cyYwELsXHftwi4gMdFNdKbX2yyYRKRKRB+os4/b9JyLzROSYiGyr9VhbEVkmInsd/7ZpYN2bHcvsFZGb3VjfMyKyy/H/95GIRDew7nnfCy1Y3x9E5FCt/8cfN7Duef/eW7C+d2vVli0imxpYt8X3n0sZY1rVD/bZu/YByUAgsBnoXWeZe4CXHbenA++6sb6OwEDH7QhgTz31jQEWe3g/ZgMx53n+x8CngABDgXUe+r8+gr3/sUf3HzAKGAhsq/XY08Bsx+3ZwFP1rNcWyHT828Zxu42b6hsP+DtuP1Vffc68F1qwvj8Av3LiPXDev/eWqq/O888Cv/PU/nPlT2s80h8CZBhjMo0x5cA7wJQ6y0wB3nTcfh+4XETEHcUZYw4bYzY6bhcDO4F4d2zbxaYA/zR2a4FoEeno5houB/YZY5pzsZ5LGGO+AgrqPFz7ffYm8JN6Vp0ALDPGFBhjTgDLgInuqM8Y87kxptJxdy3Q2dXbdVYD+88Zzvy9N9v56nNkx7XA267erie0xtCPBw7Wup/DuaF6dhnHm74QaOeW6mpxNCsNANbV8/QwEdksIp+KSB+3FmZngM9FZIOIzKrneWf2c0ubTsN/aJ7efwDtjTGH
wf5hD8TVs4w37EeAW7F/c6tPY++FlnSfo/lpXgPNY96w/0YCR40xext43pP774K1xtCv74i9bhckZ5ZpUSISDnwAPGCMKarz9EbsTRb9gb8DH7uzNofhxpiBwCTgXhEZVed5j+5DEQkEJgML63naG/afs7zhvfhroBJY0MAijb0XWspLQDfgEuAw9iaUujy+/4DrOf9Rvqf2X5O0xtDPARJq3e8M5Da0jIj4A1E07atlk4hIAPbAX2CM+bDu88aYImPMKcftJUCAiMS4qz7HdnMd/x4DPsL+Nbo2Z/ZzS5oEbDTGHK37hDfsP4ejNU1ejn+P1bOMR/ej48TxVcCNxtEAXZcT74UWYYw5aoypMsZUA682sF1P7z9/4Grg3YaW8dT+a6rWGPppQA8RSXIcDU4HFtVZZhFQ00tiGvBlQ294V3O0/70O7DTGPNfAMh1qzjGIyBDs/w/57qjPsc0wEYmouY39hN+2OostAm5y9OIZChTWNGW4SYNHV57ef7XUfp/dDPynnmWWAuNFpI2j+WK847EWJyITgUeBycaYkgaWcea90FL11T5H9NMGtuvM33tLGgfsMsbk1PekJ/dfk3n6THJTfrD3LNmD/az+rx2PPY79zQ0QjL1ZIANYDyS7sbYR2L9+bgE2OX5+DNwF3OVY5j5gO/aeCGuBy9y8/5Id297sqKNmH9auUYA5jn28FUh1Y32h2EM8qtZjHt1/2D+ADgMV2I8+b8N+nugLYK/j37aOZVOB12qte6vjvZgB3OLG+jKwt4fXvA9rerR1Apac773gpvr+5XhvbcEe5B3r1ue4f87fuzvqczz+Rs37rtaybt9/rvzRK3KVUspCWmPzjlJKqSbS0FdKKQvR0FdKKQvR0FdKKQvR0FdKKQvR0FdKKQvR0FdKKQvR0FdKKQv5//g55yCWQiOvAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x192722a5710>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "def train_rnn(num_epochs, num_steps, state_size=None, verbose=True, report_every=100):\n",
    "    \"\"\"Train the RNN graph defined above and return the averaged training losses.\n",
    "\n",
    "    Within an epoch, the final RNN state of each window is fed back through\n",
    "    `init_state` for the next window; it is reset to zeros at every new epoch.\n",
    "\n",
    "    Args:\n",
    "        num_epochs: number of epochs; every epoch uses freshly generated data.\n",
    "        num_steps: window length passed to gen_epochs; must equal the num_steps\n",
    "            the placeholders were built with.\n",
    "        state_size: width of the state fed to `init_state`. Defaults to the\n",
    "            width of `init_state` itself, so it always matches the graph.\n",
    "        verbose: print the epoch index and every averaged loss.\n",
    "        report_every: positive number of training steps averaged into each\n",
    "            reported loss.\n",
    "\n",
    "    Returns:\n",
    "        List with one mean loss per `report_every` steps, across all epochs.\n",
    "    \"\"\"\n",
    "    if state_size is None:\n",
    "        state_size = init_state.get_shape().as_list()[1]\n",
    "    training_losses = []\n",
    "    with tf.Session() as sess:\n",
    "        sess.run(tf.global_variables_initializer())\n",
    "        for idx, epoch in enumerate(gen_epochs(num_epochs, num_steps)):\n",
    "            training_state = np.zeros((batch_size, state_size))\n",
    "            window_loss = 0.0\n",
    "            if verbose:\n",
    "                print('\\nepoch', idx)\n",
    "            # Steps are counted from 1 so each report averages exactly `report_every` steps.\n",
    "            for step, (X, Y) in enumerate(epoch, start=1):\n",
    "                # Fetch only what is used: the scalar loss, the carried state and the update op.\n",
    "                step_loss, training_state, _ = sess.run(\n",
    "                    [total_loss, final_state, train_step],\n",
    "                    feed_dict={x: X, y: Y, init_state: training_state})\n",
    "                window_loss += step_loss\n",
    "                if step % report_every == 0:\n",
    "                    mean_loss = window_loss / report_every\n",
    "                    if verbose:\n",
    "                        print('第 {0} 步的平均损失 {1}'.format(step, mean_loss))\n",
    "                    training_losses.append(mean_loss)\n",
    "                    window_loss = 0.0\n",
    "    return training_losses\n",
    "\n",
    "\n",
    "training_losses = train_rnn(num_epochs=5, num_steps=num_steps, state_size=state_size)\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(training_losses)\n",
    "ax.set(title='RNN training loss', xlabel='report index (one per 100 training steps)',\n",
    "       ylabel='mean cross-entropy')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "hide_input": false,
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
